{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn import preprocessing\n",
    "\n",
    "from gnnad.graphanomaly import GNNAD\n",
    "from gnnad.plot import plot_test_anomalies, plot_predictions, plot_sensor_error_scores\n",
    "\n",
    "def normalise(X, scaler_fn):\n",
    "    scaler = scaler_fn.fit(X)\n",
    "    return pd.DataFrame(scaler.transform(X), index=X.index, columns = X.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read in training data\n",
    "X_train = pd.read_csv('./examples/herbert_train.csv', index_col=0)\n",
    "X_train.index = pd.to_datetime(X_train.index)\n",
    "\n",
    "# read in test data\n",
    "X_tmp = pd.read_csv('./examples/herbert_test.csv', index_col=0)\n",
    "X_tmp.index = pd.to_datetime(X_tmp.index)\n",
    "X_test = X_tmp.iloc[:,:8]\n",
    "y_test = X_tmp.iloc[:,8:].apply(any, axis=1)\n",
    "\n",
    "# normalise\n",
    "X_test = normalise(X_test, preprocessing.StandardScaler())\n",
    "X_train = normalise(X_train, preprocessing.StandardScaler())\n",
    "\n",
    "# create ANOOMS dict for plotting\n",
    "ANOMS = {'type1': {}}\n",
    "X_test_anoms = X_tmp.iloc[:,8:]\n",
    "\n",
    "for i in range(len(X_test_anoms.columns)):\n",
    "    anom_col_name = X_test_anoms.columns[i]\n",
    "    sensor_col_name = X_test.columns[i]\n",
    "    anom_idxs = X_test_anoms[anom_col_name][X_test_anoms[anom_col_name]].index\n",
    "\n",
    "    if len(anom_idxs) > 0:\n",
    "        ANOMS['type1'][sensor_col_name] = anom_idxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot input data\n",
    "plot_test_anomalies(X_test, ANOMS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run model\n",
    "model = GNNAD(threshold_type=\"max_validation\", topk=6, slide_win=200)\n",
    "fitted_model = model.fit(X_train, X_test, y_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model summary\n",
    "fitted_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GDN+, sensor thresholds\n",
    "preds = fitted_model.sensor_threshold_preds(tau = 99)\n",
    "fitted_model.print_eval_metrics(preds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot predictions\n",
    "plot_predictions(fitted_model, X_test, ANOMS, preds = preds, figsize=(20, 20))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_sensor_error_scores(fitted_model, X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdn_old",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
